library(data.table)
library(grid)
library(gridExtra)
library(tidyverse)
library(latex2exp)

## Project layout: every path below is built relative to the project root.
dir.project <- "./"
dir.data <- paste0(dir.project, "data/")
dir.raw <- paste0(dir.data, "raw/")
dir.generated <- paste0(dir.data, "generated/")

## Output location for the sensitivity figures; the paper directory is the same.
dir.fig <- paste0(dir.project, "doc/paper/figures/sensitivity/")
dir.fig.paper <- dir.fig

## `paper` (and the dir.* variables above) are defined before figures.R is
## sourced, presumably because that script reads them -- confirm in figures.R.
paper <- TRUE
source(paste0(dir.project, "src/r/figures.R"))

## Shared ggplot layers for the sensitivity heatmaps: tiles filled on a
## white-to-black gradient, no legend, rotated parameter labels on top, and
## a fixed order plus readable labels for the moments (y axis) and the
## perturbed parameters (x axis).
style <- list(
  geom_tile(),
  scale_fill_gradient(low = "white", high = "black"),
  theme(axis.text.x = element_text(angle = 60, vjust = 5, hjust = 0),
        legend.position = "none"),
  xlab(""),
  ylab(""),
  scale_y_discrete(
    # order of the moment rows (first entry is drawn at the bottom)
    limits = c(
      'labor_income_share',
      'R_share_in_labor_inc',
      'M_share_in_labor_inc',
      'emp_C',
      'emp_R',
      'emp_M',
      'part',
      'avg_wage_C',
      'avg_wage_R',
      'avg_wage_M',
      'part_change_C',
      'part_change_R',
      'part_change_M',
      'price_elast_robots',
      'dec_eps_9',
      'dec_eps_8',
      'dec_eps_7',
      'dec_eps_6',
      'dec_eps_5',
      'dec_eps_4',
      'dec_eps_3',
      'dec_eps_2',
      'dec_eps_1'
    ),
    labels = c(
      'price_elast_robots' = 'price-elast. robot adoption',
      'dec_eps_1' = 'elast. dec. 1',
      'dec_eps_2' = 'elast. dec. 2',
      'dec_eps_3' = 'elast. dec. 3',
      'dec_eps_4' = 'elast. dec. 4',
      'dec_eps_5' = 'elast. dec. 5',
      'dec_eps_6' = 'elast. dec. 6',
      'dec_eps_7' = 'elast. dec. 7',
      'dec_eps_8' = 'elast. dec. 8',
      'dec_eps_9' = 'elast. dec. 9',
      'emp_M' = 'empl. share manual',
      'emp_R' = 'empl. share routine',
      'emp_C' = 'empl. share cognitive',
      'part_change_M' = 'part. change manual',
      'part_change_R' = 'part. change routine',
      'part_change_C' = 'part. change cognitive',
      'avg_wage_M' = 'avg. wage manual',
      'avg_wage_R' = 'avg. wage routine',
      'avg_wage_C' = 'avg. wage cognitive',
      'part' = 'participation rate',
      # keys must match the column names / limits exactly (lowercase
      # "share"); a mismatched key means the raw column name is printed
      'M_share_in_labor_inc' = 'Share of labor inc. manual',
      'R_share_in_labor_inc' = 'Share of labor inc. routine',
      'labor_income_share' = 'labor inc. share'
    )
  ),
  scale_x_discrete(
    position = 'top',
    # left-to-right order of the perturbed parameters
    limits = c(
      'alpha',
      'rho',
      'sigma',
      'kappa_M',
      'kappa_R',
      'kappa_C',
      'kappa_MR',
      'gamma_E',
      'gamma_C',
      'gamma_M',
      'gamma_B_M',
      'gamma_B_R',
      'gamma_B_C'
    ),
    # plotmath labels; each TeX() string needs both an opening and a closing `$`
    labels = c(
      'alpha' = unname(TeX('$\\alpha$')),
      'gamma_C' = unname(TeX('$\\gamma_{C}$')),
      'gamma_B_C' = unname(TeX('$\\gamma_{B,C}$')),
      'gamma_B_M' = unname(TeX('$\\gamma_{B,M}$')),
      'gamma_B_R' = unname(TeX('$\\gamma_{B,R}$')),
      'gamma_E' = unname(TeX('$\\gamma_{E}$')),
      'gamma_M' = unname(TeX('$\\gamma_{M}$')),
      'kappa_M' = unname(TeX('$\\kappa_{M}$')),
      'kappa_C' = unname(TeX('$\\kappa_{C}$')),
      'kappa_MR' = unname(TeX('$\\kappa_{M,R}$')),
      'kappa_R' = unname(TeX('$\\kappa_{R}$')),
      'rho' = unname(TeX('$\\rho$')),
      'sigma' = unname(TeX('$\\sigma$'))
    )
  )
)
  

## Moment sensitivities exported by the MATLAB simulation; rows are keyed by
## the `par_changed` column (the perturbed parameter).
df <- paste0(dir.generated,
             'matlab/simulation/20191107_nochoice_moments_sensitivity.csv') %>%
  fread()

## Keep the parameter id plus the moments shown in the heatmaps (column order
## preserved), then replace every sensitivity by its magnitude.
df <- df %>%
  select(par_changed,
         price_elast_robots,
         part_change_M, part_change_R, part_change_C,
         avg_wage_M, avg_wage_R, avg_wage_C,
         part,
         emp_M, emp_R, emp_C,
         labor_income_share,
         M_share_in_labor_inc, R_share_in_labor_inc,
         starts_with('dec')) %>%
  mutate(across(-par_changed, abs))

## Panel B: divide each moment (column) by its largest value across all
## perturbed parameters, so every moment column peaks at 1.
df.norm_col <- df %>%
  mutate(across(-par_changed, function(col) col / max(col)))

# reshape to one row per (parameter, moment) pair for geom_tile
df.norm_col.long <- pivot_longer(df.norm_col, cols = -par_changed)

# y-axis text and ticks are hidden on this panel
p.norm_col <- ggplot(df.norm_col.long,
                     aes(x = par_changed, y = name, fill = value)) +
  theme.serif +
  style +
  theme(axis.ticks.y = element_blank(), axis.text.y = element_blank())

## Drill-down: keep only the parameters that look poorly identified and
## rescale each moment by its maximum within that subset.
df.norm_col_drill <- df %>%
  filter(par_changed %in% c('gamma_B_R', 'kappa_MR', 'kappa_R')) %>%
  mutate(across(-par_changed, function(col) col / max(col)))

# reshape to one row per (parameter, moment) pair for geom_tile
df.norm_col_drill.long <- pivot_longer(df.norm_col_drill, cols = -par_changed)

p.norm_col_drill <- ggplot(df.norm_col_drill.long,
                           aes(x = par_changed, y = name, fill = value)) +
  style

## Panel A: normalize by row, i.e. divide every moment sensitivity of a
## perturbed parameter by the largest sensitivity of that same parameter.

# Row-wise maximum over all moment columns. Columns are selected by name
# (not by position), and pmax() avoids apply()'s data.frame-to-matrix
# coercion. The result is a plain vector, so no helper column (previously
# named `max`, which shadowed base::max) has to be added and dropped again.
# NOTE(review): dec_eps_7 still enters this maximum even though it is blanked
# below. In rows where it is the maximum, the remaining cells stay below 1.
# Confirm this is intended.
norm_row_max <- do.call(pmax, unname(as.list(select(df, -par_changed))))

# normalize by the row maximum
df.norm_row <- df %>%
  mutate(across(-par_changed, function(col) col / norm_row_max))

# filling dec_eps_7 with NA since otherwise it dominates the colour scale
df.norm_row$dec_eps_7 <- NA

# wide to long
df.norm_row.long <- df.norm_row %>% pivot_longer(-par_changed)

p.norm_row <- ggplot(df.norm_row.long, aes(par_changed, name, fill = value)) +
  style +
  theme.serif


## Variant of panel A without the decile-elasticity ('dec') rows.
p.norm_row_no_dec <- df.norm_row.long %>%
  filter(!grepl('dec', name)) %>%
  ggplot(aes(x = par_changed, y = name, fill = value)) +
  style



## Put panel A (normalized by parameter, with moment labels) next to panel B
## (normalized by moment, y labels hidden) and write the combined PDF.
## Panel A gets the wider share (10:8), presumably to fit its y-axis labels.
plots_combined <- grid.arrange(
  p.norm_row + ggtitle(TeX('A: normalized by parameter')),
  p.norm_col + ggtitle('B: normalized by moment'),
  nrow = 1,
  ncol = 2,
  widths = c(10, 8)
)

ggsave(
  filename = paste0(dir.fig, "moments_sensitivity.pdf"),
  plot = plots_combined,
  units = "cm",
  width = 20,
  height = 14
)

